import numpy as np
import pandas as pd
from impyute.imputation.cs import mice
from scipy import stats


# Mean imputation
def mean_interpolate(df, feature_columns):
    """Fill missing values in numeric feature columns with the column mean.

    Args:
        df: Input DataFrame. It is not modified; a new frame is returned.
        feature_columns: Column labels to impute. Labels that are not among
            the columns selected by ``select_dtypes(include=['int', 'float'])``
            (including labels absent from ``df``) are skipped.

    Returns:
        A new DataFrame in which every processed column is cast to float and
        has its missing values replaced by that column's mean.
    """
    # replace() hands back a new frame, so the caller's DataFrame stays as is.
    df = df.replace(np.nan, np.nan)
    # Processed columns only ever become float, so the numeric column set is
    # loop-invariant: compute it once instead of once per feature column.
    numeric_columns = set(df.select_dtypes(include=['int', 'float']))
    for col in feature_columns:
        if col in numeric_columns:
            df[col] = df[col].fillna(df[col].mean()).astype(float)
    return df


# Mode (most frequent value) imputation
def most_frequent_interpolate(df, feature_columns):
    """Fill missing values in numeric feature columns with the column mode.

    When several values tie for most frequent, the first entry returned by
    ``Series.mode()`` is used.

    Args:
        df: Input DataFrame. It is not modified; a new frame is returned.
        feature_columns: Column labels to impute. Labels that are not among
            the columns selected by ``select_dtypes(include=['int', 'float'])``
            (including labels absent from ``df``) are skipped.

    Returns:
        A new DataFrame in which every processed column is cast to float and
        has its missing values replaced by that column's mode. A column with
        no observed values has no mode and is left unfilled.
    """
    # replace() hands back a new frame, so the caller's DataFrame stays as is.
    df = df.replace(np.nan, np.nan)
    # Loop-invariant: processed columns only ever become float.
    numeric_columns = set(df.select_dtypes(include=['int', 'float']))
    for col in feature_columns:
        if col not in numeric_columns:
            continue
        column = df[col]
        modes = column.mode()
        # mode() is empty when every value is missing; indexing it would
        # raise IndexError, so leave such a column unfilled instead.
        if not modes.empty:
            column = column.fillna(modes.iloc[0])
        df[col] = column.astype(float)
    return df


# Median imputation
def median_interpolate(df, feature_columns):
    """Fill missing values in numeric feature columns with the column median.

    Args:
        df: Input DataFrame. It is not modified; a new frame is returned.
        feature_columns: Column labels to impute. Labels that are not among
            the columns selected by ``select_dtypes(include=['int', 'float'])``
            (including labels absent from ``df``) are skipped.

    Returns:
        A new DataFrame in which every processed column is cast to float and
        has its missing values replaced by that column's median.
    """
    # replace() hands back a new frame, so the caller's DataFrame stays as is.
    df = df.replace(np.nan, np.nan)
    # Processed columns only ever become float, so the numeric column set is
    # loop-invariant: compute it once instead of once per feature column.
    numeric_columns = set(df.select_dtypes(include=['int', 'float']))
    for col in feature_columns:
        if col in numeric_columns:
            df[col] = df[col].fillna(df[col].median()).astype(float)
    return df


# Linear interpolation
def liner_interpolate(df, feature_columns):
    """Fill missing values in numeric feature columns by linear interpolation.

    Interpolation runs in both directions (``limit_direction='both'``), so
    missing values at the start or end of a column are filled as well.

    Args:
        df: Input DataFrame. It is not modified; a new frame is returned.
        feature_columns: Column labels to impute. Labels that are not among
            the columns selected by ``select_dtypes(include=['int', 'float'])``
            (including labels absent from ``df``) are skipped.

    Returns:
        A new DataFrame in which every processed column is cast to float and
        has its missing values replaced by interpolated ones.
    """
    # replace() hands back a new frame, so the caller's DataFrame stays as is.
    df = df.replace(np.nan, np.nan)
    # Loop-invariant: processed columns only ever become float.
    numeric_columns = set(df.select_dtypes(include=['int', 'float']))
    for col in feature_columns:
        if col in numeric_columns:
            # interpolate() leaves observed values untouched, so wrapping it
            # in fillna() adds nothing.
            df[col] = df[col].interpolate(
                method='linear', limit_direction='both'
            ).astype(float)
    return df


# Multiple imputation (MICE)
def mice_fill(df, feature_columns):
    """Impute the given feature columns jointly with ``mice``.

    Imputation only runs when at least two feature columns are given and all
    of them are among the columns selected by
    ``select_dtypes(include=['int', 'float'])``; otherwise the data is
    returned without any imputation.

    Args:
        df: Input DataFrame. It is not modified; a new frame is returned.
        feature_columns: Labels of the columns to impute together.

    Returns:
        A new DataFrame whose feature columns hold the values returned by
        ``mice``, cast to float.
    """
    # replace() hands back a new frame, so the caller's DataFrame stays as is.
    df = df.replace(np.nan, np.nan)
    feature_columns = list(feature_columns)
    numeric_columns = set(df.select_dtypes(include=['int', 'float']))
    # Use `<=`, not `<`: a strict-subset test silently skipped imputation
    # whenever the caller selected every numeric column.
    if len(feature_columns) < 2 or not set(feature_columns) <= numeric_columns:
        return df
    # NOTE(review): assumes mice() returns a 2-D array with the same row and
    # column layout as its input - confirm against the impyute documentation.
    imputed = mice(df.loc[:, feature_columns].values)
    # Reuse df's own index: column assignment aligns on labels, so a fresh
    # RangeIndex would turn the column into NaN for any non-default index.
    df_fill = pd.DataFrame(imputed, index=df.index, columns=feature_columns)
    for col in feature_columns:
        df[col] = df_fill[col].astype(float)
    return df


# Same-period imputation
def period_fill(df, feature_columns, period):
    """Fill each missing value with the mean of its same-phase observations.

    Two rows of a column are in the same phase when their positions differ by
    a multiple of ``period + 1``. A missing value is replaced by the mean of
    the originally observed (non-missing) values in its phase; values filled
    during this call are never used as sources.

    Args:
        df: Input DataFrame. It is not modified; a new frame is returned.
        feature_columns: Labels of the columns to impute. Each column must be
            convertible to float.
        period: Integer; rows ``period + 1`` positions apart share a phase.

    Returns:
        A new DataFrame with every feature column cast to float and imputed.
        If some missing value has no observed value in its phase, a
        Chinese-language message string (asking for a shorter interval) is
        returned instead of a DataFrame.

    Raises:
        ValueError: If ``period + 1`` is smaller than 1.
    """
    step = int(period + 1)
    if step < 1:
        raise ValueError("period must be a non-negative integer")
    # replace() hands back a new frame, so the caller's DataFrame stays as is.
    df = df.replace(np.nan, np.nan)
    for col in feature_columns:
        observed = np.asarray(df[col], dtype=float)
        filled = observed.copy()
        for i in np.flatnonzero(np.isnan(observed)):
            # Every position congruent to i modulo step, first to last row.
            # Position i itself is missing, so the NaN mask drops it too.
            same_phase = observed[i % step::step]
            valid = same_phase[~np.isnan(same_phase)]
            if valid.size == 0:
                return "请选择更短时间间隔！"
            filled[i] = valid.mean()
        # Assign positionally: going through a RangeIndex-ed frame would
        # misalign (and blank out) the column when df has another index.
        df[col] = filled
    return df


# Return the data that should be highlighted
def fillsDic(df, df_fill, cols):
    """Collect, per column, the distinct values that imputation introduced.

    Only columns of ``cols`` that contain missing values in ``df`` are
    reported. For those, ``df`` and ``df_fill`` are compared row by row in
    positional order, and every value of ``df_fill`` that differs from the
    value in ``df`` is collected.

    Args:
        df: DataFrame before imputation.
        df_fill: DataFrame after imputation, with the same row order as ``df``.
        cols: Column labels to inspect.

    Returns:
        A dict mapping each reported column label to a list of its distinct
        changed values, in no particular order. Cells that are still missing
        in ``df_fill`` are not reported.
    """
    dic = {}
    for col in cols:
        if not df[col].isnull().any():
            continue
        # Convert only the column at hand rather than the whole frame.
        before = df[col].tolist()
        after = df_fill[col].tolist()
        still_missing = df_fill[col].isnull().tolist()
        # NaN compares unequal to everything, itself included, so an unfilled
        # cell would otherwise be reported as a filled value.
        changed = {
            new
            for old, new, missing in zip(before, after, still_missing)
            if not missing and new != old
        }
        dic[col] = list(changed)
    return dic


# Check whether at least one feature column contains missing values.
def checkNull(df, cols):
    """Report whether any of the given columns contains a missing value.

    Args:
        df: DataFrame to inspect.
        cols: Column labels to check; every label is looked up in ``df``.

    Returns:
        1 if at least one column has a missing value. Otherwise a notice is
        printed and 0 is returned.
    """
    # Build the full list first so that every label is looked up, even after
    # a column with missing values has already been found.
    has_missing = [bool(df[name].isnull().any()) for name in cols]
    if any(has_missing):
        return 1
    print("没有含有缺失值的特征列，请重新选择特征列!")
    return 0


def not_null_col(df, cols):
    """Return the labels from ``cols`` whose columns have no missing values.

    Args:
        df: DataFrame to inspect.
        cols: Column labels to check.

    Returns:
        A list of the labels, in the order given, whose column in ``df``
        contains no missing value.
    """
    return [name for name in cols if df[name].notnull().all()]


# if __name__ == '__main__':
#     df = pd.read_excel('MissingData.xlsx', engine='openpyxl')
#     cols = ['电气机械和器材制造业', '房地产业']
#     print(not_null_col(df, cols))

    # df_fill = period_fill(df[:8], feature_columns, 6).round(0)  # 自测更换function
    # print(fillsDic(df, df_fill, feature_columns))
    # result = {}
    # non_steady_col = ['feature1', 'feature2']
    # result['not_norm'] = str('提示：检测到以下特征序列不平稳:',non_steady_col)
    # print(result)
    # print(df_fill)
